import time
import torch

torch.cuda.init()

# Size of the shared cache region requested below: 1 GiB.
SHARED_CACHE_BYTES = 1 * 1024 * 1024 * 1024

# NOTE(review): `allot_shared_cache` / `insert_shared_cache` are not part of
# upstream PyTorch's public `torch.cuda` API; they presumably come from a
# patched build. Their exact semantics (units, and what the first argument of
# `insert_shared_cache` means) are assumptions here — confirm against that build.
torch.cuda.allot_shared_cache(SHARED_CACHE_BYTES)

s = torch.cuda.Stream()
with torch.cuda.stream(s):
    # presumably makes the region starting at offset 0 available to the
    # allocator while stream `s` is current — TODO confirm
    torch.cuda.insert_shared_cache(0, SHARED_CACHE_BYTES)
    # Generated on the CPU, then copied to the GPU on stream `s`.
    a = torch.rand([4200, 2]).cuda()

# CUDA work is asynchronous: the host-to-device copy of `a` may still be in
# flight here. Drain all queued device work before starting the clock so the
# timed region below measures only the exp kernel.
torch.cuda.synchronize()

t0 = time.perf_counter()
with torch.cuda.stream(s):
    b = torch.exp(a)
    # Block until the kernel completes; otherwise only the launch is timed.
    torch.cuda.synchronize()
t1 = time.perf_counter()

# NOTE: this is a single, un-warmed run, so the figure may still include
# one-time first-launch overhead for the exp kernel.
print("time: ", t1 - t0)